{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "# Sequence to sequence model (S2S)\n",
    "(c) Deniz Yuret, 2018-2020.\n",
    "\n",
    "Based on ([Sutskever et al. 2014](https://papers.nips.cc/paper/5346-sequence-to-sequence-learning-with-neural-networks.pdf)).\n",
    "\n",
    "S2S models learn to map input sequences to output sequences using an encoder RNN and a decoder RNN. Note that this is an instructional example written in low-level Julia/Knet, so it is slow to train. For a faster, higher-level implementation, see `@doc RNN`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "using Knet, CUDA"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "<img src=\"images/seq2seq.png\"/>(<a href=\"https://papers.nips.cc/paper/5346-sequence-to-sequence-learning-with-neural-networks.pdf\">image source</a>)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "# S2S model definition: every parameter lives in one Dict{Symbol,Any} keyed by layer name.\n",
    "\n",
    "\"\"\"\n",
    "    initmodel(H, V; atype)\n",
    "\n",
    "Return the S2S parameter dictionary for hidden size `H` and vocabulary size `V`.\n",
    "Weights are Xavier-initialised and stored as `atype`, which defaults to\n",
    "`KnetArray{Float32}` when CUDA is functional and `Array{Float32}` otherwise.\n",
    "\"\"\"\n",
    "function initmodel(H, V; atype=(CUDA.functional() ? KnetArray{Float32} : Array{Float32}))\n",
    "    winit(dims...) = atype(xavier(dims...))\n",
    "    # Pair arguments are evaluated left to right, so parameters are drawn in the listed order.\n",
    "    return Dict{Symbol,Any}(\n",
    "        :state0 => [winit(1,H), winit(1,H)],     # initial (hidden, cell) rows\n",
    "        :embed1 => winit(V,H),                   # encoder token embeddings\n",
    "        :encode => [winit(2H,4H), winit(1,4H)],  # encoder LSTM (weight, bias)\n",
    "        :embed2 => winit(V,H),                   # decoder token embeddings\n",
    "        :decode => [winit(2H,4H), winit(1,4H)],  # decoder LSTM (weight, bias)\n",
    "        :output => [winit(H,V), winit(1,V)],     # output layer (weight, bias)\n",
    "    )\n",
    "end;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "# S2S loss function and its gradient\n",
    "\n",
    "# s2s(model, inputs, outputs) returns the negative log-likelihood (summed over all batch\n",
    "# rows) of the `outputs` token sequence given the `inputs` token sequence.\n",
    "#   inputs, outputs: vectors of time steps; each step is a vector of token ids, one per\n",
    "#   batch element (the layout produced by `minibatch`).\n",
    "function s2s(model, inputs, outputs)\n",
    "    # Encoder: start from the learned initial state, replicated once per batch row.\n",
    "    state = initstate(inputs[1], model[:state0])\n",
    "    for input in inputs\n",
    "        input = onehotrows(input, model[:embed1])\n",
    "        input = input * model[:embed1]\n",
    "        state = lstm(model[:encode], state, input)\n",
    "    end\n",
    "    # Decoder: the first input is the EOS token (id 1), which also marks the end of the output.\n",
    "    EOS = eosmatrix(outputs[1], model[:embed2])\n",
    "    input = EOS * model[:embed2]\n",
    "    sumlogp = 0\n",
    "    for output in outputs\n",
    "        state = lstm(model[:decode], state, input)\n",
    "        ypred = predict(model[:output], state[1])\n",
    "        ygold = onehotrows(output, model[:embed2])\n",
    "        # ygold is one-hot, so this picks the log-probability of the gold token in each row.\n",
    "        sumlogp += sum(ygold .* logp(ypred,dims=2))\n",
    "        # Teacher forcing: the gold token, not the prediction, is the next decoder input.\n",
    "        input = ygold * model[:embed2]\n",
    "    end\n",
    "    # One extra step: the decoder must predict EOS after the last output token.\n",
    "    state = lstm(model[:decode], state, input)\n",
    "    ypred = predict(model[:output], state[1])\n",
    "    sumlogp += sum(EOS .* logp(ypred,dims=2))\n",
    "    return -sumlogp\n",
    "end\n",
    "\n",
    "# s2sgrad(model, inputs, outputs) returns (gradients w.r.t. model, loss).\n",
    "s2sgrad = gradloss(s2s);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "<img src=\"images/s2s-dims.png\"/>(<a href=\"https://docs.google.com/drawings/d/1BR871g8k4jpI-mKeXiJfpY5Jl5cKcognvH7hHSugQds/edit?usp=sharing\">image source</a>)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "# An LSTM implementation in Knet\n",
    "\n",
    "# lstm(param, state, input) -> (hidden, cell): one LSTM step for a whole minibatch.\n",
    "#   param = (weight, bias); the 4h pre-activation columns are split into four h-wide\n",
    "#           groups: forget gate, input gate, output gate, candidate cell value.\n",
    "#   state = (hidden, cell), one row per batch element; input has the same number of rows.\n",
    "function lstm(param, state, input)\n",
    "    W, b = param\n",
    "    h_prev, c_prev = state\n",
    "    h = size(h_prev, 2)\n",
    "    z = hcat(input, h_prev) * W .+ b\n",
    "    slab(k) = z[:, (k-1)*h+1:k*h]   # k-th h-wide column group of z\n",
    "    f = sigm.(slab(1))\n",
    "    i = sigm.(slab(2))\n",
    "    o = sigm.(slab(3))\n",
    "    g = tanh.(slab(4))\n",
    "    c_next = c_prev .* f + i .* g\n",
    "    return (o .* tanh.(c_next), c_next)\n",
    "end;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "# S2S helper functions\n",
    "\n",
    "# predict(param, input): output layer, returns the unnormalised scores input * weight .+ bias.\n",
    "predict(param, input) = input * param[1] .+ param[2]\n",
    "\n",
    "# initstate(idx, state0): broadcast the initial (hidden, cell) rows to length(idx) rows.\n",
    "# NOTE(review): presumably a zero matrix is added (rather than copying) so the result\n",
    "# stays on the AutoGrad tape and gradients reach state0 - confirm.\n",
    "function initstate(idx, state0)\n",
    "    nrows = length(idx)\n",
    "    zeros_like(x) = fill!(similar(value(x), nrows, length(x)), 0)\n",
    "    h0, c0 = state0\n",
    "    return (h0 .+ zeros_like(h0), c0 .+ zeros_like(c0))\n",
    "end\n",
    "\n",
    "# onehotrows(idx, embeddings): one-hot matrix with a row per token id in idx (a single Int\n",
    "# or a vector of Ints) and size(embeddings,1) columns, converted to the array type of\n",
    "# embeddings so it can be multiplied with them directly. Ids are bounds-checked, so an\n",
    "# out-of-vocabulary id raises a BoundsError instead of writing outside the matrix.\n",
    "function onehotrows(idx, embeddings)\n",
    "    z = zeros(Float32, length(idx), size(embeddings,1))\n",
    "    for (row, tok) in enumerate(idx)\n",
    "        z[row, tok] = 1\n",
    "    end\n",
    "    oftype(value(embeddings),z)\n",
    "end\n",
    "\n",
    "# eosmatrix(idx, embeddings): length(idx) x size(embeddings,1) one-hot matrix selecting\n",
    "# token 1 (EOS) in every row. The matrix is cached and shared between calls (do not\n",
    "# mutate it). It is rebuilt when the requested size or the array type of embeddings\n",
    "# changes, e.g. when switching between a CPU model and a GPU model.\n",
    "let EOS=nothing; global eosmatrix\n",
    "function eosmatrix(idx, embeddings)\n",
    "    nrows,ncols = length(idx), size(embeddings,1)\n",
    "    atype = typeof(value(embeddings))\n",
    "    if EOS === nothing || size(EOS) != (nrows,ncols) || !(EOS isa atype)\n",
    "        EOS = zeros(Float32,nrows,ncols)\n",
    "        EOS[:,1] .= 1\n",
    "        EOS = oftype(value(embeddings), EOS)\n",
    "    end\n",
    "    return EOS\n",
    "end\n",
    "end;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "# Use reversing English words as an example task.\n",
    "# readdata reads one word per line from `file` and converts each character to an int.\n",
    "# If `file` does not exist, the tutorial word list is first downloaded to that path.\n",
    "# It also sets the globals `strings` (the words) and `tok2int` / `int2tok` (the vocabulary).\n",
    "\n",
    "function readdata(file=\"words\")\n",
    "    isfile(file) || download(\"http://people.csail.mit.edu/deniz/models/tutorial/words\", file)\n",
    "    global strings = map(chomp,readlines(file))\n",
    "    global tok2int = Dict{Char,Int}()\n",
    "    global int2tok = Vector{Char}()\n",
    "    push!(int2tok,'\\n'); tok2int['\\n']=1 # We use '\\n'=>1 as the EOS token\n",
    "    # Id of character c, assigning the next free id the first time c is seen.\n",
    "    function tokid(c)\n",
    "        get!(tok2int, c) do\n",
    "            push!(int2tok, c)\n",
    "            length(int2tok)\n",
    "        end\n",
    "    end\n",
    "    sequences = Vector{Int}[Int[tokid(c) for c in w] for w in strings]\n",
    "    return sequences\n",
    "end;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "102305-element Array{Array{Int64,1},1}\n",
      "102305-element Array{SubString{String},1}\n",
      "70-element Array{Char,1}\n",
      "Dict{Char,Int64} with 70 entries\n",
      "Alpert\n",
      "Alpert's\n",
      "Alphard\n",
      "Alphard's\n",
      "Alphecca\n"
     ]
    }
   ],
   "source": [
    "sequences = readdata();\n",
    "foreach(x -> println(summary(x)), (sequences, strings, int2tok, tok2int))\n",
    "foreach(println, strings[501:505])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"792-element Array{Any,1}\""
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Minibatch sequences putting equal length sequences together.\n",
    "# Each minibatch is a vector of time steps: step j holds the j-th token of each of the\n",
    "# `batchsize` sequences (all of the same length). Leftover sequences that never fill a\n",
    "# batch are dropped, and empty sequences are skipped since they have no time steps.\n",
    "\n",
    "function minibatch(sequences, batchsize)\n",
    "    buckets = Dict{Int,Vector{Vector{Int}}}()   # length => sequences waiting for a batch\n",
    "    data = Any[]\n",
    "    for s in sequences\n",
    "        n = length(s)\n",
    "        n == 0 && continue   # s2s and train index the first time step of every batch\n",
    "        bucket = get!(() -> Vector{Int}[], buckets, n)\n",
    "        push!(bucket, s)\n",
    "        if length(bucket) == batchsize\n",
    "            push!(data, [[seq[j] for seq in bucket] for j in 1:n])\n",
    "            empty!(bucket)\n",
    "        end\n",
    "    end\n",
    "    return data\n",
    "end\n",
    "\n",
    "batchsize = statesize = 128\n",
    "vocabsize = length(int2tok)\n",
    "data = minibatch(sequences, batchsize)\n",
    "summary(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "train (generic function with 1 method)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Training loop: one pass over `data` that trains the model to output each input sequence\n",
    "# reversed. Returns the average loss per predicted token (every output token plus EOS).\n",
    "\n",
    "function train(model, data, opts)\n",
    "    total = ntokens = 0\n",
    "    for batch in data\n",
    "        grads, loss = s2sgrad(model, batch, reverse(batch))\n",
    "        update!(model, grads, opts)\n",
    "        total += loss\n",
    "        nsteps = length(batch) + 1   # output time steps plus the final EOS prediction\n",
    "        ntokens += nsteps * length(batch[1])\n",
    "    end\n",
    "    return total / ntokens\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train from scratch? stdin> n\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "\"Dict{Symbol,Any} with 6 entries\""
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "file = \"rnnreverse134.jld2\"; model = opts = nothing; GC.gc(true) # clean memory from previous run\n",
    "# Answers starting with y/Y train a new model. Anything else, including an empty line,\n",
    "# loads the pretrained model, downloading it first if needed.\n",
    "print(\"Train from scratch? \")\n",
    "if startswith(lowercase(strip(readline())), \"y\")\n",
    "    # Initialize model and optimization parameters\n",
    "    model = initmodel(statesize,vocabsize)\n",
    "    opts = optimizers(model,Adam)\n",
    "    @time for epoch=1:10\n",
    "        @time loss = train(model,data,opts) # ~17 sec/epoch\n",
    "        println((epoch,loss))\n",
    "    end\n",
    "    Knet.save(file,\"model\",model)\n",
    "else\n",
    "    isfile(file) || download(\"http://people.csail.mit.edu/deniz/models/tutorial/$file\",file)\n",
    "    model = Knet.load(file,\"model\")\n",
    "end\n",
    "summary(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "# Test on some examples:\n",
    "\n",
    "# translate(model, str; maxlen=100): greedy decoding of `str` with a batch of one.\n",
    "# Encodes `str` character by character, then repeatedly feeds back the most likely\n",
    "# token until EOS (token 1) is predicted or `maxlen` characters have been produced.\n",
    "# Characters missing from `tok2int` raise a KeyError.\n",
    "function translate(model, str; maxlen=100)\n",
    "    state = model[:state0]\n",
    "    for c in str\n",
    "        input = onehotrows(tok2int[c], model[:embed1])\n",
    "        input = input * model[:embed1]\n",
    "        state = lstm(model[:encode], state, input)\n",
    "    end\n",
    "    input = eosmatrix(1, model[:embed2]) * model[:embed2]\n",
    "    output = Char[]\n",
    "    for _ in 1:maxlen\n",
    "        state = lstm(model[:decode], state, input)\n",
    "        pred = predict(model[:output], state[1])\n",
    "        tok = argmax(vec(Array(pred)))   # greedy choice of the next token\n",
    "        tok == 1 && break                # EOS ends the output\n",
    "        push!(output, int2tok[tok])\n",
    "        input = onehotrows(tok, model[:embed2]) * model[:embed2]\n",
    "    end\n",
    "    String(output)\n",
    "end;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"nrocirpac\""
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "translate(model, \"capricorn\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "julia.ipynb",
   "provenance": [],
   "version": "0.3.2"
  },
  "kernelspec": {
   "display_name": "Julia 1.5.0",
   "language": "julia",
   "name": "julia-1.5"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.5.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
